4ad4ce
@@ -18,6 +18,7 @@
 
 package org.apache.hadoop.hive.ql.parse.spark;
 
+import java.util.ArrayList;
 import java.util.HashSet;
 import java.util.LinkedList;
 import java.util.List;
@@ -36,63 +37,72 @@
 import org.apache.hadoop.hive.ql.lib.NodeProcessorCtx;
 import org.apache.hadoop.hive.ql.parse.SemanticException;
 
-import com.google.common.base.Preconditions;
 
 /**
  * This processor triggers on SparkPartitionPruningSinkOperator. For a operator tree like
  * this:
  *
  * Original Tree:
- *     TS    TS
- *      |     |
- *     FIL   FIL
- *      |     | \
- *     RS     RS SEL
- *       \   /    |
- *        JOIN   GBY
- *                |
- *               SPARKPRUNINGSINK
+ *     TS1       TS2
+ *      |          |
+ *      FIL       FIL
+ *      |          |
+ *      RS      /   \   \
+ *      |      |    \    \
+ *      |     RS  SEL  SEL
+ *      \   /      |     |
+ *      JOIN      GBY   GBY
+ *                  |    |
+ *                  |  SPARKPRUNINGSINK
+ *                  |
+ *              SPARKPRUNINGSINK
  *
  * It removes the branch containing SPARKPRUNINGSINK from the original operator tree, and splits it into
  * two separate trees:
- *
- * Tree #1:                 Tree #2:
- *     TS    TS               TS
- *      |     |                |
- *     FIL   FIL              FIL
- *      |     |                |
- *     RS     RS              SEL
- *       \   /                 |
- *       JOIN                 GBY
- *                             |
- *                            SPARKPRUNINGSINK
- *
+ * Tree #1:                       Tree #2
+ *      TS1    TS2                 TS2
+ *      |      |                    |
+ *      FIL    FIL                 FIL
+ *      |       |                   |_____
+ *      RS     SEL                  |     \
+ *      |       |                   SEL    SEL
+ *      |     RS                    |      |
+ *      \   /                       GBY    GBY
+ *      JOIN                        |      |
+ *                                  |    SPARKPRUNINGSINK
+ *                                 SPARKPRUNINGSINK
+ *
  * For MapJoinOperator, this optimizer will not do anything - it should be executed within
  * the same SparkTask.
  */
 public class SplitOpTreeForDPP implements NodeProcessor {
+
   @Override
   public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
                         Object... nodeOutputs) throws SemanticException {
     SparkPartitionPruningSinkOperator pruningSinkOp = (SparkPartitionPruningSinkOperator) nd;
     GenSparkProcContext context = (GenSparkProcContext) procCtx;
 
+    for (Operator<?> op : context.pruningSinkSet) {
+      if (pruningSinkOp.getOperatorId().equals(op.getOperatorId())) {
+        return null;
+      }
+    }
+
     // Locate the op where the branch starts
     // This is guaranteed to succeed since the branch always follow the pattern
     // as shown in the first picture above.
-    Operator<?> filterOp = pruningSinkOp;
-    Operator<?> selOp = null;
-    while (filterOp != null) {
-      if (filterOp.getNumChild() > 1) {
+    Operator<?> branchingOp = pruningSinkOp;
+    while (branchingOp != null) {
+      if (branchingOp.getNumChild() > 1) {
         break;
       } else {
-        selOp = filterOp;
-        filterOp = filterOp.getParentOperators().get(0);
+        branchingOp = branchingOp.getParentOperators().get(0);
       }
     }
 
     // Check if this is a MapJoin. If so, do not split.
-    for (Operator<?> childOp : filterOp.getChildOperators()) {
+    for (Operator<?> childOp : branchingOp.getChildOperators()) {
       if (childOp instanceof ReduceSinkOperator &&
           childOp.getChildOperators().get(0) instanceof MapJoinOperator) {
         context.pruningSinkSet.add(pruningSinkOp);
@@ -103,8 +113,10 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
     List<Operator<?>> roots = new LinkedList<Operator<?>>();
     collectRoots(roots, pruningSinkOp);
 
-    List<Operator<?>> savedChildOps = filterOp.getChildOperators();
-    filterOp.setChildOperators(Utilities.makeList(selOp));
+    List<Operator<?>> savedChildOps = branchingOp.getChildOperators();
+    List<Operator<?>> firstNodesOfPruningBranch = findFirstNodesOfPruningBranch(branchingOp);
+    branchingOp.setChildOperators(Utilities.makeList(firstNodesOfPruningBranch.toArray(new
+        Operator<?>[firstNodesOfPruningBranch.size()])));
 
     // Now clone the tree above selOp
     List<Operator<?>> newRoots = SerializationUtilities.cloneOperatorTree(roots);
@@ -115,27 +127,49 @@
public Object process(Node nd, Stack<Node> stack, NodeProcessorCtx procCtx,
     }
     context.clonedPruningTableScanSet.addAll(newRoots);
 
+    // Collect all pruning sink operators reachable from the old roots
+    List<Operator<?>> oldsinkList = new ArrayList<>();
+    for (Operator<?> root : roots) {
+      SparkUtilities.collectOp(oldsinkList, root, SparkPartitionPruningSinkOperator.class);
+    }
+
     // Restore broken links between operators, and remove the branch from the original tree
-    filterOp.setChildOperators(savedChildOps);
-    filterOp.removeChild(selOp);
+    branchingOp.setChildOperators(savedChildOps);
+    for (Operator selOp : firstNodesOfPruningBranch) {
+      branchingOp.removeChild(selOp);
+    }
 
-    // Find the cloned PruningSink and add it to pruningSinkSet
-    Set<Operator<?>> sinkSet = new HashSet<Operator<?>>();
+    // Collect all pruning sink operators reachable from the cloned (new) roots
+    Set<Operator<?>> sinkSet = new HashSet<>();
     for (Operator<?> root : newRoots) {
       SparkUtilities.collectOp(sinkSet, root, SparkPartitionPruningSinkOperator.class);
     }
-    Preconditions.checkArgument(sinkSet.size() == 1,
-        "AssertionError: expected to only contain one SparkPartitionPruningSinkOperator," +
-            " but found " + sinkSet.size());
-    SparkPartitionPruningSinkOperator clonedPruningSinkOp =
-        (SparkPartitionPruningSinkOperator) sinkSet.iterator().next();
-    clonedPruningSinkOp.getConf().setTableScan(pruningSinkOp.getConf().getTableScan());
-    context.pruningSinkSet.add(clonedPruningSinkOp);
 
+    int i = 0;
+    for (Operator<?> clonedPruningSinkOp : sinkSet) {
+      SparkPartitionPruningSinkOperator oldsinkOp = (SparkPartitionPruningSinkOperator) oldsinkList.get(i++);
+      ((SparkPartitionPruningSinkOperator) clonedPruningSinkOp).getConf().setTableScan(oldsinkOp.getConf().getTableScan());
+      context.pruningSinkSet.add(clonedPruningSinkOp);
+
+    }
     return null;
   }
 
-  /**
+  // Find the children of the given branching operator whose subtrees contain a
+  // SparkPartitionPruningSink.
+  private List<Operator<?>> findFirstNodesOfPruningBranch(Operator<?> branchingOp) {
+    List<Operator<?>> res = new ArrayList<>();
+    for (Operator child : branchingOp.getChildOperators()) {
+      List<Operator<?>> pruningList = new ArrayList<>();
+      SparkUtilities.collectOp(pruningList, child, SparkPartitionPruningSinkOperator.class);
+      if (pruningList.size() > 0) {
+        res.add(child);
+      }
+    }
+    return res;
+  }
+
+  /**
    * Recursively collect all roots (e.g., table scans) that can be reached via this op.
    * @param result contains all roots can be reached via op
    * @param op the op to examine.
